from torch import nn
from transformers import Wav2Vec2Model

class AudioTaggingModel(nn.Module):
    """Audio tagger: pretrained wav2vec 2.0 encoder, temporal mean pooling, linear head.

    Args:
        num_tags: Number of output logits (one per tag).
        pretrained_name: Hugging Face checkpoint passed to
            ``Wav2Vec2Model.from_pretrained``. Defaults to the original
            ``"facebook/wav2vec2-base"``.
    """

    def __init__(self, num_tags=20, pretrained_name="facebook/wav2vec2-base"):
        super().__init__()
        self.wav2vec = Wav2Vec2Model.from_pretrained(pretrained_name)
        # Size the head from the backbone config instead of hard-coding 768,
        # so checkpoints with a different hidden size also work.
        self.classifier = nn.Linear(self.wav2vec.config.hidden_size, num_tags)

    def forward(self, input_values, attention_mask=None):
        """Return tag logits for a batch of waveforms.

        Args:
            input_values: Waveform batch as accepted by ``Wav2Vec2Model``
                (presumably float ``(batch, num_samples)`` -- confirm with caller).
            attention_mask: Optional sample-level mask aligned with
                ``input_values``, non-zero for real samples and 0 for padding.
                When given, padded frames are excluded from the temporal mean.
                When omitted, every frame is averaged (original behaviour).

        Returns:
            Output of the linear head applied to the pooled encoder states,
            one logit per tag for each batch item.
        """
        # Per the HF Wav2Vec2 docs, group-norm checkpoints (e.g. wav2vec2-base)
        # were trained without an attention mask and should only see
        # zero-padded input; layer-norm checkpoints expect the mask.
        backbone_mask = (
            attention_mask
            if attention_mask is not None
            and self.wav2vec.config.feat_extract_norm == "layer"
            else None
        )
        hidden = self.wav2vec(
            input_values, attention_mask=backbone_mask
        ).last_hidden_state

        if attention_mask is None:
            pooled = hidden.mean(dim=1)
        else:
            # The convolutional feature extractor downsamples time, so the
            # sample-level mask must be mapped onto encoder frames first.
            frame_mask = self.wav2vec._get_feature_vector_attention_mask(
                hidden.shape[1], attention_mask
            )
            pooled = self._masked_mean(hidden, frame_mask)
        return self.classifier(pooled)

    @staticmethod
    def _masked_mean(hidden, frame_mask):
        """Average ``hidden`` over dim 1, counting only frames where ``frame_mask`` is set."""
        weights = frame_mask.unsqueeze(-1).to(hidden.dtype)
        # clamp avoids division by zero for a row that is entirely padding
        return (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)